{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create entry points to Spark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark import SparkConf, SparkContext\n",
    "\n",
    "# getOrCreate() hands back the context that is already running in this kernel\n",
    "# (if any), so re-running the cell no longer fails with\n",
    "# \"ValueError: Cannot run multiple SparkContexts at once\".\n",
    "# The conf (local master) only takes effect when a new context is created.\n",
    "sc = SparkContext.getOrCreate(SparkConf().setMaster('local'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# Build the session (or pick up the existing one) with a named app and one\n",
    "# config entry; the chain is wrapped in parentheses instead of using\n",
    "# backslash line continuations.\n",
    "spark = (\n",
    "    SparkSession.builder\n",
    "    .appName(\"Python Spark SQL basic example\")\n",
    "    .config(\"spark.some.config.option\", \"some-value\")\n",
    "    .getOrCreate()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load iris data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "iris = spark.read.csv('data/iris.csv', header=True, inferSchema=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+-----------+------------+-----------+-------+\n",
      "|sepal_length|sepal_width|petal_length|petal_width|species|\n",
      "+------------+-----------+------------+-----------+-------+\n",
      "|         5.1|        3.5|         1.4|        0.2| setosa|\n",
      "|         4.9|        3.0|         1.4|        0.2| setosa|\n",
      "|         4.7|        3.2|         1.3|        0.2| setosa|\n",
      "|         4.6|        3.1|         1.5|        0.2| setosa|\n",
      "|         5.0|        3.6|         1.4|        0.2| setosa|\n",
      "+------------+-----------+------------+-----------+-------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "iris.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('sepal_length', 'double'),\n",
       " ('sepal_width', 'double'),\n",
       " ('petal_length', 'double'),\n",
       " ('petal_width', 'double'),\n",
       " ('species', 'string')]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iris.dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+------------------+-------------------+------------------+------------------+---------+\n",
      "|summary|      sepal_length|        sepal_width|      petal_length|       petal_width|  species|\n",
      "+-------+------------------+-------------------+------------------+------------------+---------+\n",
      "|  count|               150|                150|               150|               150|      150|\n",
      "|   mean| 5.843333333333335| 3.0540000000000007|3.7586666666666693|1.1986666666666672|     null|\n",
      "| stddev|0.8280661279778637|0.43359431136217375| 1.764420419952262|0.7631607417008414|     null|\n",
      "|    min|               4.3|                2.0|               1.0|               0.1|   setosa|\n",
      "|    max|               7.9|                4.4|               6.9|               2.5|virginica|\n",
      "+-------+------------------+-------------------+------------------+------------------+---------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "iris.describe().show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merge features to create a features column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.linalg import Vectors\n",
    "from pyspark.sql import Row"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------+-------+\n",
      "|         features|species|\n",
      "+-----------------+-------+\n",
      "|[5.1,3.5,1.4,0.2]| setosa|\n",
      "|[4.9,3.0,1.4,0.2]| setosa|\n",
      "|[4.7,3.2,1.3,0.2]| setosa|\n",
      "|[4.6,3.1,1.5,0.2]| setosa|\n",
      "|[5.0,3.6,1.4,0.2]| setosa|\n",
      "+-----------------+-------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.feature import VectorAssembler\n",
    "\n",
    "# Every column except the last one (`species`) is a numeric measurement.\n",
    "# VectorAssembler packs them into a single `features` vector column while\n",
    "# staying inside the DataFrame API, so rows are not routed through a\n",
    "# Python lambda over the underlying RDD.\n",
    "feature_cols = iris.columns[:-1]\n",
    "assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')\n",
    "iris2 = assembler.transform(iris).select('features', 'species')\n",
    "iris2.show(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Index label column with `StringIndexer`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import StringIndexer\n",
    "from pyspark.ml import Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build pipeline\n",
    "Try to use a pipeline whenever you can, to get used to this format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "stringindexer = StringIndexer(inputCol='species', outputCol='label')\n",
    "stages = [stringindexer]\n",
    "pipeline = Pipeline(stages=stages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transform data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------+-------+-----+\n",
      "|         features|species|label|\n",
      "+-----------------+-------+-----+\n",
      "|[5.1,3.5,1.4,0.2]| setosa|  2.0|\n",
      "|[4.9,3.0,1.4,0.2]| setosa|  2.0|\n",
      "|[4.7,3.2,1.3,0.2]| setosa|  2.0|\n",
      "|[4.6,3.1,1.5,0.2]| setosa|  2.0|\n",
      "|[5.0,3.6,1.4,0.2]| setosa|  2.0|\n",
      "+-----------------+-------+-----+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Step 1: fit the pipeline on the assembled data.\n",
    "# Step 2: apply the fitted pipeline to the same data to add the `label` column.\n",
    "indexer_model = pipeline.fit(iris2)\n",
    "iris_df = indexer_model.transform(iris2)\n",
    "iris_df.show(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check the data one more time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+---------+------------------+\n",
      "|summary|  species|             label|\n",
      "+-------+---------+------------------+\n",
      "|  count|      150|               150|\n",
      "|   mean|     null|               1.0|\n",
      "| stddev|     null|0.8192319205190403|\n",
      "|    min|   setosa|               0.0|\n",
      "|    max|virginica|               2.0|\n",
      "+-------+---------+------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "iris_df.describe().show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('features', 'vector'), ('species', 'string'), ('label', 'double')]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iris_df.dtypes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Naive Bayes classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split data into training and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train, test = iris_df.randomSplit([0.8, 0.2], seed=1234)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build cross-validation model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import NaiveBayes\n",
    "naivebayes = NaiveBayes(featuresCol=\"features\", labelCol=\"label\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Parameter grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.tuning import ParamGridBuilder\n",
    "\n",
    "# `smoothing` is a float-valued param, so the candidate values are written as\n",
    "# floats rather than relying on an int -> double conversion further down the\n",
    "# stack (NOTE(review): whether ints are converted depends on the Spark version).\n",
    "param_grid = (\n",
    "    ParamGridBuilder()\n",
    "    .addGrid(naivebayes.smoothing, [0.0, 1.0, 2.0, 4.0, 8.0])\n",
    "    .build()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluator\n",
    "There are three categories in the label column. Therefore, we use `MulticlassClassificationEvaluator`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
    "\n",
    "# The keyword values below are the evaluator's documented defaults, spelled\n",
    "# out so it is visible which columns are compared and which metric is used.\n",
    "evaluator = MulticlassClassificationEvaluator(\n",
    "    predictionCol='prediction',\n",
    "    labelCol='label',\n",
    "    metricName='f1',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Build cross-validation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.tuning import CrossValidator\n",
    "\n",
    "# A fixed seed pins the random assignment of rows to folds, so the selected\n",
    "# smoothing value and the downstream metrics are repeatable across runs.\n",
    "crossvalidator = CrossValidator(\n",
    "    estimator=naivebayes,\n",
    "    estimatorParamMaps=param_grid,\n",
    "    evaluator=evaluator,\n",
    "    seed=1234,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fit cross-validation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "crossvalidation_mode = crossvalidator.fit(train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prediction on training and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------+-------+-----+--------------------+--------------------+----------+\n",
      "|         features|species|label|       rawPrediction|         probability|prediction|\n",
      "+-----------------+-------+-----+--------------------+--------------------+----------+\n",
      "|[4.4,3.2,1.3,0.2]| setosa|  2.0|[-12.271684174579...|[0.19444573537358...|       2.0|\n",
      "|[4.5,2.3,1.3,0.3]| setosa|  2.0|[-11.119812088601...|[0.27298108810790...|       2.0|\n",
      "|[4.6,3.1,1.5,0.2]| setosa|  2.0|[-12.527795092630...|[0.21626072148339...|       2.0|\n",
      "|[4.6,3.2,1.4,0.2]| setosa|  2.0|[-12.570525272901...|[0.19958486427181...|       2.0|\n",
      "|[4.6,3.4,1.4,0.3]| setosa|  2.0|[-13.131900980855...|[0.19913306361405...|       2.0|\n",
      "+-----------------+-------+-----+--------------------+--------------------+----------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pred_train = crossvalidation_mode.transform(train)\n",
    "pred_train.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------+-------+-----+--------------------+--------------------+----------+\n",
      "|         features|species|label|       rawPrediction|         probability|prediction|\n",
      "+-----------------+-------+-----+--------------------+--------------------+----------+\n",
      "|[4.3,3.0,1.1,0.1]| setosa|  2.0|[-11.379239696581...|[0.17893625340183...|       2.0|\n",
      "|[4.4,2.9,1.4,0.2]| setosa|  2.0|[-11.901296005920...|[0.22552241734351...|       2.0|\n",
      "|[4.4,3.0,1.3,0.2]| setosa|  2.0|[-11.944026186191...|[0.20863892696497...|       2.0|\n",
      "|[4.8,3.1,1.6,0.2]| setosa|  2.0|[-12.826636190952...|[0.22170131985173...|       2.0|\n",
      "|[5.0,3.3,1.4,0.2]| setosa|  2.0|[-13.089838835897...|[0.18446024811785...|       2.0|\n",
      "+-----------------+-------+-----+--------------------+--------------------+----------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pred_test = crossvalidation_mode.transform(test)\n",
    "pred_test.show(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Best model from cross validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The parameter smoothing has best value: 4.0\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the value is read through the model's private `_java_obj`\n",
    "# handle; confirm whether a public accessor exists on the Spark version in use.\n",
    "best_nb_model = crossvalidation_mode.bestModel\n",
    "print(\"The parameter smoothing has best value:\",\n",
    "      best_nb_model._java_obj.getSmoothing())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prediction accuracy\n",
    "Four accuracy metrics are available for this evaluator:\n",
    "* `f1`\n",
    "* `weightedPrecision`\n",
    "* `weightedRecall`\n",
    "* `accuracy`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Prediction accuracy on training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training data (f1): 0.9682539682539681 \n",
      " training data (weightedPrecision):  0.9682539682539681 \n",
      " training data (weightedRecall):  0.9682539682539681 \n",
      " training data (accuracy):  0.9682539682539683\n"
     ]
    }
   ],
   "source": [
    "# Each metric is supplied as a per-call param override instead of through\n",
    "# setMetricName(), so the shared `evaluator` (the same object passed to\n",
    "# CrossValidator) is not mutated by running this cell.\n",
    "print('training data (f1):', evaluator.evaluate(pred_train, {evaluator.metricName: 'f1'}), \"\\n\",\n",
    "     'training data (weightedPrecision): ', evaluator.evaluate(pred_train, {evaluator.metricName: 'weightedPrecision'}), \"\\n\",\n",
    "     'training data (weightedRecall): ', evaluator.evaluate(pred_train, {evaluator.metricName: 'weightedRecall'}), \"\\n\",\n",
    "     'training data (accuracy): ', evaluator.evaluate(pred_train, {evaluator.metricName: 'accuracy'}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Prediction accuracy on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test data (f1): 0.958119658119658 \n",
      " test data (weightedPrecision):  0.9635416666666667 \n",
      " test data (weightedRecall):  0.9583333333333334 \n",
      " test data (accuracy):  0.9583333333333334\n"
     ]
    }
   ],
   "source": [
    "# Each metric is supplied as a per-call param override instead of through\n",
    "# setMetricName(), so the shared `evaluator` (the same object passed to\n",
    "# CrossValidator) is not mutated by running this cell.\n",
    "print('test data (f1):', evaluator.evaluate(pred_test, {evaluator.metricName: 'f1'}), \"\\n\",\n",
    "     'test data (weightedPrecision): ', evaluator.evaluate(pred_test, {evaluator.metricName: 'weightedPrecision'}), \"\\n\",\n",
    "     'test data (weightedRecall): ', evaluator.evaluate(pred_test, {evaluator.metricName: 'weightedRecall'}), \"\\n\",\n",
    "     'test data (accuracy): ', evaluator.evaluate(pred_test, {evaluator.metricName: 'accuracy'}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Confusion matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Confusion matrix on training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {Row(label=0.0, prediction=0.0): 41,\n",
       "             Row(label=0.0, prediction=1.0): 2,\n",
       "             Row(label=1.0, prediction=0.0): 2,\n",
       "             Row(label=1.0, prediction=1.0): 41,\n",
       "             Row(label=2.0, prediction=2.0): 40})"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_conf_mat = pred_train.select('label', 'prediction')\n",
    "# countByValue() tallies identical (label, prediction) rows directly, so the\n",
    "# rows do not have to be paired with an index first just to count the keys.\n",
    "train_conf_mat.rdd.countByValue()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Confusion matrix on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {Row(label=0.0, prediction=0.0): 7,\n",
       "             Row(label=1.0, prediction=0.0): 1,\n",
       "             Row(label=1.0, prediction=1.0): 6,\n",
       "             Row(label=2.0, prediction=2.0): 10})"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_conf_mat = pred_test.select('label', 'prediction')\n",
    "# countByValue() tallies identical (label, prediction) rows directly, so the\n",
    "# rows do not have to be paired with an index first just to count the keys.\n",
    "test_conf_mat.rdd.countByValue()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**From the confusion matrices on both training and test data, we can see that there are only a few mismatches between prediction and label values.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
